package jsonrpc

import (
	"context"
	"encoding/json"
	"fmt"
	"io/ioutil"
	"net/http"
	"os"
	"os/signal"
	"strconv"
	"strings"
	"sync"
	"syscall"
	"time"

	"github.com/google/uuid"
)

// jsonrpcVersion is the JSON-RPC protocol version written into every response
// and the version incoming requests must declare.
const jsonrpcVersion = "2.0"

// Ctx is the per-call context handed to a registered handler. It exposes the
// HTTP request that carried the call and lets the handler decode the call's
// params and record its JSON-RPC response.
type Ctx interface {
	// Req returns the HTTP request that carried this call.
	Req() *http.Request
	// Header returns the value of the named HTTP request header.
	Header(key string) string
	// Bind decodes the JSON-RPC params of this call into v (a pointer).
	Bind(v interface{}) error
	// Success records a successful response carrying result.
	Success(result interface{}) error
	// Error records an error response built from code, message and data.
	Error(code int, message string, data interface{}) error
	// Logger returns the logger available to this call.
	Logger() Logger
}

// Logger is the logging interface used by ServeMux and exposed to handlers
// through Ctx.Logger. ServeHTTP calls SetBase with the per-request fields
// produced by Base; presumably implementations attach those fields to later
// entries (confirm against the implementation). Each level method takes a
// message followed by optional extra fields.
type Logger interface {
	Debug(msg string, fields ...interface{})
	Info(msg string, fields ...interface{})
	Warn(msg string, fields ...interface{})
	Error(msg string, fields ...interface{})
	SetBase(base map[string]string)
}

// ServeMux is an http.Handler that routes JSON-RPC 2.0 calls to handlers by
// method name. Create one with NewServer, add handlers with Register and
// start serving with Run.
type ServeMux struct {
	addr            string // listen address handed to http.Server by Run
	// methods maps a JSON-RPC method name to its handler (see Register).
	methods         map[string]func(Ctx) error
	// logger is set by AddLogger; NewServer leaves it nil.
	logger          Logger
	// defaultFunction serves non-POST requests (see SetDefaultFunc).
	defaultFunction func(http.ResponseWriter, *http.Request)
}

// rpcCtx is the Ctx implementation ServeHTTP creates for each JSON-RPC
// request it dispatches to a registered handler.
type rpcCtx struct {
	// r is the HTTP request that carried the call.
	r         *http.Request
	// request is the decoded JSON-RPC request being handled.
	request   *RPCRequest
	// isSingle is true when the body was a single request rather than a
	// batch. NOTE(review): not read by any code visible here.
	isSingle  bool
	// responses is shared by every context of one HTTP request; Success and
	// Error append to it and ServeHTTP serializes it after dispatch.
	responses *RPCResponses
	// logger is the mux's Logger, copied from ServeMux.logger.
	logger    Logger
}

// Req returns the HTTP request that carried this JSON-RPC call.
func (c *rpcCtx) Req() *http.Request { return c.r }

// Header returns the first value of the named header on the HTTP request
// carrying this call (case-insensitive), or "" when it is absent.
func (c *rpcCtx) Header(key string) string {
	headers := c.r.Header
	return headers.Get(key)
}

// Bind decodes the current request's Params into toType, which should be a
// pointer. Params are re-encoded to JSON and then decoded, so toType follows
// the usual encoding/json Unmarshal rules.
func (c *rpcCtx) Bind(toType interface{}) error {
	raw, err := json.Marshal(c.request.Params)
	if err != nil {
		return err
	}
	return json.Unmarshal(raw, toType)
}

// Logger returns the Logger the mux attached to this context.
func (c *rpcCtx) Logger() Logger { return c.logger }

// NewClient builds an RPCClient for endpoint whose CustomHeaders are the base
// fields Base derives from the incoming request. Presumably these are sent as
// HTTP headers so trace fields such as X-Trace-ID reach downstream services;
// confirm against NewClientWithOpts.
func (c *rpcCtx) NewClient(endpoint string) RPCClient {
	return NewClientWithOpts(endpoint, &RPCClientOpts{CustomHeaders: Base(c.r)})
}

// Success appends a successful response carrying val as its result, tagged
// with the current request's ID. It always returns nil, so a handler can end
// with `return c.Success(v)`.
func (c *rpcCtx) Success(val interface{}) error {
	*c.responses = append(*c.responses, &RPCResponse{
		JSONRPC: jsonrpcVersion,
		ID:      c.request.ID,
		Result:  val,
	})
	return nil
}

// SetError builds an RPCError from its parts. A handler that returns an
// RPCError has it copied as-is into the response by ServeHTTP, so returning
// SetError(...) controls the error code the client receives.
func SetError(code int, message string, data interface{}) RPCError {
	var e RPCError
	e.Code = code
	e.Message = message
	e.Data = data
	return e
}

// Error appends an error response built from code, message and data and
// tagged with the current request's ID. It always returns nil, so a handler
// can end with `return c.Error(...)`.
func (c *rpcCtx) Error(code int, message string, data interface{}) error {
	rpcErr := &RPCError{Code: code, Message: message, Data: data}
	*c.responses = append(*c.responses, &RPCResponse{
		JSONRPC: jsonrpcVersion,
		Error:   rpcErr,
		ID:      c.request.ID,
	})
	return nil
}

// Register installs handler as the function invoked for JSON-RPC requests
// whose method equals pattern; registering the same pattern again replaces
// the previous handler. The handler table is a plain map with no locking, so
// register every method before the mux starts serving.
func (m *ServeMux) Register(pattern string, handler func(Ctx) error) {
	// Allocate lazily so a zero-value ServeMux (one not built by NewServer)
	// does not panic with "assignment to entry in nil map".
	if m.methods == nil {
		m.methods = make(map[string]func(Ctx) error)
	}
	m.methods[pattern] = handler
}

// SetDefaultFunc sets the handler ServeHTTP delegates to for requests whose
// HTTP method is not POST, i.e. anything that is not a JSON-RPC call.
func (m *ServeMux) SetDefaultFunc(fallback func(http.ResponseWriter, *http.Request)) {
	m.defaultFunction = fallback
}

// Base collects the cross-service headers of r into a map that ServeHTTP uses
// as the logger's base fields and NewClient uses as custom headers. Empty
// values are filled in: X-Request-Time with the current Unix time in
// nanoseconds, X-Trace-ID with a freshly generated UUID, and X-Client-URI with
// the request URL. X-URI is always the request URL.
func Base(r *http.Request) map[string]string {
	h := r.Header
	uri := r.URL.String()
	fields := map[string]string{
		"X-Platform":    h.Get("X-Platform"),
		"X-App-Code":    h.Get("X-App-Code"),
		"X-App-Version": h.Get("X-App-Version"),
		"X-Real-IP":     h.Get("X-Real-IP"),
		"X-URI":         uri,
	}

	requestTime := h.Get("X-Request-Time")
	if requestTime == "" {
		requestTime = strconv.FormatInt(time.Now().UnixNano(), 10)
	}
	fields["X-Request-Time"] = requestTime

	traceID := h.Get("X-Trace-ID")
	if traceID == "" {
		traceID = uuid.New().String()
	}
	fields["X-Trace-ID"] = traceID

	clientURI := h.Get("X-Client-URI")
	if clientURI == "" {
		clientURI = uri
	}
	fields["X-Client-URI"] = clientURI

	return fields
}

// ServeHTTP implements http.Handler for JSON-RPC 2.0 over HTTP.
//
// A POST body is decoded either as a single request object or, when its first
// non-whitespace byte is '[', as a batch. Each request declaring version
// "2.0" is dispatched to the handler registered for its method, and the
// collected responses are written back as JSON: one object for a single call,
// an array for a batch. Malformed JSON gets a -32700 Parse error response.
// Non-POST requests go to the function set by SetDefaultFunc, or are rejected
// with 405 Method Not Allowed when none is set.
//
// Every exit path emits one span-log entry whose base is the request's Base
// fields plus X-Is-Span-Log, X-Receive-Time and X-Response-Time (UnixNano).
//
// NOTE(review): m.logger is shared by all concurrent requests and re-based
// with SetBase on each one, so base fields from overlapping requests can mix.
// Fixing that needs a per-request Logger, which the Logger interface lacks.
func (m *ServeMux) ServeHTTP(w http.ResponseWriter, r *http.Request) {
	base := Base(r)
	if m.logger != nil {
		m.logger.SetBase(base)
	}
	// The span log works on a copy of base stamped with this service's
	// receive time; X-Response-Time is overwritten when the entry is emitted.
	spanBase := make(map[string]string, len(base)+3)
	for k, v := range base {
		spanBase[k] = v
	}
	spanBase["X-Is-Span-Log"] = "1"
	spanBase["X-Receive-Time"] = strconv.FormatInt(time.Now().UnixNano(), 10)
	spanBase["X-Response-Time"] = spanBase["X-Receive-Time"]

	if r.Method != http.MethodPost {
		if m.defaultFunction != nil {
			m.defaultFunction(w, r)
			m.logSpan(spanBase, true, "HTTP JSON RPC Default Handle Function - Method!=\"POST\"")
			return
		}
		m.logSpan(spanBase, true, "HTTP JSON RPC Handle - Method!=\"POST\"")
		fmt.Println("HTTP JSON RPC Handle - Method!=\"POST\"")
		// Reply explicitly instead of leaving the client an empty 200 OK.
		w.Header().Set("Allow", http.MethodPost)
		http.Error(w, http.StatusText(http.StatusMethodNotAllowed), http.StatusMethodNotAllowed)
		return
	}

	defer r.Body.Close()
	body, err := ioutil.ReadAll(r.Body)
	if err != nil {
		m.logSpan(spanBase, true, err.Error())
		fmt.Println(err)
		http.Error(w, http.StatusText(http.StatusBadRequest), http.StatusBadRequest)
		return
	}

	requests, responses, single, err := m.decodeRPCBody(body)
	if err != nil {
		fmt.Println(err)
		m.logSpan(spanBase, true, err.Error())
		// JSON-RPC 2.0 requires a Parse error response for invalid JSON.
		parseErr := &RPCResponse{
			JSONRPC: jsonrpcVersion,
			Error: &RPCError{
				Code:    -32700,
				Message: "Parse error",
				Data:    "Invalid JSON was received by the server.",
			},
		}
		if data, mErr := json.Marshal(parseErr); mErr == nil {
			w.Header().Set("Content-Type", "application/json")
			_, _ = w.Write(data)
		}
		return
	}

	// Base invents X-Trace-ID / X-Request-Time when the caller omitted them.
	// Pin those values onto a copy of the request so handlers, and clients
	// they build from it (which call Base again), reuse the values logged here
	// instead of generating fresh, unrelated ones.
	if r.Header.Get("X-Trace-ID") == "" || r.Header.Get("X-Request-Time") == "" {
		r = r.Clone(r.Context())
		for _, key := range []string{"X-Trace-ID", "X-Request-Time"} {
			if r.Header.Get(key) == "" {
				r.Header.Set(key, base[key])
			}
		}
	}
	m.dispatchRPC(r, requests, &responses, single)

	var data []byte
	switch {
	case len(responses) == 0:
		// Nothing to send: JSON-RPC 2.0 says the server returns nothing
		// rather than an empty array.
	case single:
		data, err = json.Marshal(responses[0])
	default:
		data, err = json.Marshal(responses)
	}
	if err != nil {
		msg := fmt.Sprintf("HTTP JSON RPC Handle - json.Marshal: %v", err)
		fmt.Println(msg)
		m.logSpan(spanBase, true, msg)
		http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
		return
	}
	m.logSpan(spanBase, false, "成功响应数据", data)
	if data != nil {
		w.Header().Set("Content-Type", "application/json")
	}
	// A failed write means the client went away; there is no one to report to.
	_, _ = w.Write(data)
}

// logSpan stamps X-Response-Time on spanBase, installs it as the logger's
// base and emits msg at Error level when isErr is set, Info level otherwise.
// It does nothing when no Logger has been attached with AddLogger.
func (m *ServeMux) logSpan(spanBase map[string]string, isErr bool, msg string, fields ...interface{}) {
	if m.logger == nil {
		return
	}
	spanBase["X-Response-Time"] = strconv.FormatInt(time.Now().UnixNano(), 10)
	m.logger.SetBase(spanBase)
	if isErr {
		m.logger.Error(msg, fields...)
		return
	}
	m.logger.Info(msg, fields...)
}

// decodeRPCBody splits a POST body into the requests to dispatch and the
// Invalid Request responses already owed for entries whose JSONRPC version is
// not "2.0". The bool result is true for a single request (and for an empty
// batch, which is answered with one Invalid Request object). A non-nil error
// means the body is not valid JSON for the detected shape.
func (m *ServeMux) decodeRPCBody(body []byte) ([]RPCRequest, RPCResponses, bool, error) {
	var (
		requests  []RPCRequest
		responses RPCResponses
	)
	invalid := func() *RPCResponse {
		return &RPCResponse{
			JSONRPC: jsonrpcVersion,
			Error: &RPCError{
				Code:    -32600,
				Message: "Invalid Request",
				Data:    "The JSON sent is not a valid Request object.",
			},
		}
	}

	// JSON allows leading whitespace, so skip it before sniffing for a batch.
	if !strings.HasPrefix(strings.TrimLeft(string(body), " \t\r\n"), "[") {
		var request RPCRequest
		if err := json.Unmarshal(body, &request); err != nil {
			return nil, nil, true, fmt.Errorf("HTTP JSON RPC Handle - json.Unmarshal: %w", err)
		}
		if request.JSONRPC == jsonrpcVersion {
			requests = append(requests, request)
		} else {
			responses = append(responses, invalid())
		}
		return requests, responses, true, nil
	}

	var batch []RPCRequest
	if err := json.Unmarshal(body, &batch); err != nil {
		return nil, nil, false, fmt.Errorf("HTTP JSON RPC Batch Handle - json.Unmarshal: %w", err)
	}
	if len(batch) == 0 {
		return nil, append(responses, invalid()), true, nil
	}
	for _, req := range batch {
		if req.JSONRPC == jsonrpcVersion {
			requests = append(requests, req)
		} else {
			responses = append(responses, invalid())
		}
	}
	return requests, responses, false, nil
}

// dispatchRPC runs each request through the handler registered for its
// method and appends the outcome to *responses. The outcome is one of:
// whatever the handler records via Success/Error; an error response for a
// non-nil returned error (an RPCError is sent as-is, anything else as -32000
// "Server error"); or -32601 "Method not found" for unregistered methods.
func (m *ServeMux) dispatchRPC(r *http.Request, requests []RPCRequest, responses *RPCResponses, single bool) {
	for i := range requests {
		// Point at the slice element rather than a range variable so each
		// context owns its request even if a handler retains the Ctx.
		req := &requests[i]
		handler, ok := m.methods[req.Method]
		if !ok {
			*responses = append(*responses, &RPCResponse{
				JSONRPC: jsonrpcVersion,
				Error: &RPCError{
					Code:    -32601,
					Message: "Method not found",
					Data:    "The called method was not found on the server",
				},
				ID: req.ID,
			})
			continue
		}
		c := &rpcCtx{
			r:         r,
			request:   req,
			isSingle:  single,
			responses: responses,
			logger:    m.logger,
		}
		err := handler(c)
		if err == nil {
			continue
		}
		if e, ok := err.(RPCError); ok {
			*responses = append(*responses, &RPCResponse{
				JSONRPC: jsonrpcVersion,
				Error:   &e,
				ID:      req.ID,
			})
		} else {
			_ = c.Error(-32000, "Server error", err.Error())
		}
	}
}

// NewServer creates a ServeMux that will listen on addr once Run is called.
// It does not start listening itself. Add handlers with Register and attach a
// Logger with AddLogger.
func NewServer(addr string) *ServeMux {
	return &ServeMux{
		addr:    addr,
		methods: map[string]func(Ctx) error{},
	}
}

// AddLogger sets the Logger used by the mux and passed to every request
// context it creates, and returns m so the call can be chained after
// NewServer.
func (m *ServeMux) AddLogger(logger Logger) *ServeMux {
	m.logger = logger
	return m
}

// Run serves m on the address given to NewServer and blocks until the server
// stops. On SIGINT or SIGTERM it calls http.Server.Shutdown, giving in-flight
// requests up to 15 seconds to finish, and returns only after that shutdown
// has completed. The result is ListenAndServe's error, so a signal-triggered
// stop yields http.ErrServerClosed.
func (m *ServeMux) Run() error {
	fmt.Printf("http: Listen and serve on %s ...\n", m.addr)
	srv := &http.Server{
		Addr:    m.addr,
		Handler: m,
	}

	// signal.Notify never blocks on send, so an unbuffered channel can miss a
	// signal that arrives while nobody is receiving.
	exit := make(chan os.Signal, 1)
	signal.Notify(exit, syscall.SIGINT, syscall.SIGTERM)
	defer signal.Stop(exit)

	// stopped releases the watcher goroutine when the server exits on its own
	// (e.g. the address is already in use), so the goroutine does not leak.
	stopped := make(chan struct{})
	var wg sync.WaitGroup
	wg.Add(1)
	go func() {
		defer wg.Done()
		select {
		case <-exit:
		case <-stopped:
			return
		}
		// Bound how long Shutdown may wait for active connections to drain.
		ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
		defer cancel()
		if err := srv.Shutdown(ctx); err != nil {
			fmt.Println(err)
		}
	}()

	err := srv.ListenAndServe()
	close(stopped)
	// ListenAndServe returns as soon as Shutdown begins; wait for Shutdown to
	// finish draining before handing control back to the caller.
	wg.Wait()
	return err
}
